import os
import glob
from copy import deepcopy

import numpy as np

from collections.abc import Sequence
from torch.utils.data import Dataset
from Augmentations.Build_Augmentation import build_augmentation
from Utils.Logger import print_log
from Datasets.Build_Dataloader import datasets


@datasets.register_module()
class S3DIS(Dataset):
    """Dataset whose samples are directories of ``.npy`` assets grouped by split.

    Every entry matched by ``<data_path>/<split>/*`` is treated as one sample.
    The ``.npy`` files inside a sample whose stem is listed in
    ``VALID_ASSETS`` are loaded into a dict keyed by that stem, normalised to
    fixed dtypes and then passed through the augmentation pipeline built from
    ``transform``.

    Args:
        data_path: Root directory that contains one sub-directory per split.
        mode: Run mode. When it equals ``"train"`` the mode (and, if a
            ``SphereCrop`` transform is configured, its ``point_max``) is
            attached to every sample dict.
        split: A single split name or a sequence of split names.
        transform: Augmentation configuration handed to
            ``build_augmentation``. Items are expected to expose ``NAME``
            (and ``point_max`` for ``SphereCrop``) as attributes.
        loop: Number of times the sample list is repeated per epoch.
    """

    # Stems (file name without ".npy") that are loaded from a sample
    # directory; every other file in the directory is ignored.
    VALID_ASSETS = [
        "coord",
        "color",
        "normal",
        "strength",
        "segment",
        "instance",
        "pose",
    ]

    def __init__(self, data_path, mode, split=('Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'), transform=None, loop=1):
        super().__init__()
        self.data_path = data_path
        self.split = split
        # ``transforms`` is the built pipeline; ``transform`` keeps the raw
        # configuration so ``get_data`` can inspect it.
        self.transforms = build_augmentation(transform)
        self.loop = loop
        self.mode = mode
        self.transform = transform

        self.data_list = self.get_data_list()
        print_log('The size of %s data is %d x %d' % (split, len(self.data_list), self.loop), logger='S3DIS')

    def get_data_list(self):
        """Return the paths of all samples belonging to the configured split(s).

        Paths are sorted within each split, and splits are concatenated in the
        order given, so the index -> sample mapping is deterministic
        (``glob`` itself returns entries in arbitrary order).

        Raises:
            NotImplementedError: If ``split`` is neither a string nor a
                sequence.
        """
        # ``str`` is itself a Sequence, so it has to be checked first.
        if isinstance(self.split, str):
            splits = [self.split]
        elif isinstance(self.split, Sequence):
            splits = self.split
        else:
            raise NotImplementedError(
                f"Unsupported split type: {type(self.split).__name__}"
            )

        data_list = []
        for split in splits:
            pattern = os.path.join(self.data_path, split, "*")
            data_list.extend(sorted(glob.glob(pattern)))
        return data_list

    def get_data_name(self, idx):
        """Return ``"<parent dir>-<entry name>"`` for the sample at ``idx``."""
        remain, room_name = os.path.split(self.data_list[idx])
        _, area_name = os.path.split(remain)
        return f"{area_name}-{room_name}"

    def get_data(self, idx):
        """Load the sample at ``idx`` (wrapped modulo the list length).

        Returns:
            A dict holding the loaded assets plus ``"name"``. ``coord``,
            ``color`` and ``normal`` are cast to ``float32``; ``segment`` and
            ``instance`` are flattened ``int32`` arrays, filled with ``-1``
            (one entry per row of ``coord``) when the asset is absent. In
            train mode ``"mode"`` and, if a ``SphereCrop`` transform is
            configured, ``"num_points"`` are added as well.

        Raises:
            KeyError: If ``segment`` or ``instance`` has to be synthesised
                but the sample has no ``coord`` asset.
        """
        # ``__len__`` reports ``len(data_list) * loop``, so indices wrap.
        idx = idx % len(self.data_list)
        name = self.get_data_name(idx)
        data_path = self.data_list[idx]

        data_dict = {}
        for asset in os.listdir(data_path):
            stem, ext = os.path.splitext(asset)
            if ext != ".npy" or stem not in self.VALID_ASSETS:
                continue
            data_dict[stem] = np.load(os.path.join(data_path, asset))

        data_dict["name"] = name
        if self.mode == "train":
            data_dict["mode"] = self.mode
            # ``transform`` defaults to None; iterating it directly would
            # raise TypeError for a train-mode dataset without augmentations.
            if self.transform is not None:
                for item in self.transform:
                    if item.NAME == "SphereCrop":
                        data_dict["num_points"] = item.point_max

        for key in ("coord", "color", "normal"):
            if key in data_dict:
                data_dict[key] = data_dict[key].astype(np.float32)

        for key in ("segment", "instance"):
            if key in data_dict:
                data_dict[key] = data_dict[key].reshape([-1]).astype(np.int32)
                continue
            if "coord" not in data_dict:
                raise KeyError(
                    f"coord asset missing in {data_path!r}; "
                    f"cannot create placeholder {key!r} labels"
                )
            # No labels on disk: mark every point with -1.
            data_dict[key] = np.full(
                data_dict["coord"].shape[0], -1, dtype=np.int32
            )

        return data_dict

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and run it through the augmentation pipeline."""
        data_dict = self.get_data(idx)
        return self.transforms(data_dict)

    def __len__(self):
        """Return the number of samples per epoch (sample count times ``loop``)."""
        return len(self.data_list) * self.loop
